package cn.doitedu.datayi.etl

import cn.doitedu.datayi.utils.BitmapSerDe
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.roaringbitmap.RoaringBitmap
import cn.doitedu.datayi.utils.BitmapSerDe.{de, ser}


object BitmapOrUDAF extends UserDefinedAggregateFunction{

  // 定义函数的输入参数的结构（字段名：字段类型）
  override def inputSchema: StructType = StructType(Seq(StructField("bm",DataTypes.BinaryType)))

  // 定义存储中间计算结果的结构（字段名：字段类型）
  override def bufferSchema: StructType = StructType(Seq(StructField("buff",DataTypes.BinaryType)))

  // 定义函数的返回值字段类型
  override def dataType: DataType = DataTypes.BinaryType

  // 定义函数是否是“相同输入”永远得到“相同输出”（涉及到sql优化时的策略选择）
  override def deterministic: Boolean = true

  // 初始化中间结果缓存结构
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    val bitmap: RoaringBitmap = RoaringBitmap.bitmapOf()
    val bytes: Array[Byte] = ser(bitmap)

    buffer.update(0,bytes)
  }

  // 来一条数据，更新一次中间结果
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    val inputBytes: Array[Byte] = input.getAs[Array[Byte]](0)
    val bufferBytes: Array[Byte] = buffer.getAs[Array[Byte]](0)

    // 反序列化输入的bitmap
    val inputBitmap: RoaringBitmap = de(inputBytes)

    // 反序列化缓存中的bitmap
    val bufferBitmap: RoaringBitmap = de(bufferBytes)

    // 合并两个bitmap，并将合并后的结果序列化，更新到缓存
    bufferBitmap.or(inputBitmap)
    buffer.update(0,ser(bufferBitmap))

  }

  // 将多个局部聚合结果合并成一个全局聚合结果
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    update(buffer1,buffer2)
  }

  // 从全局聚合结果中，返回最终函数输出值
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Array[Byte]](0)
  }

}
